{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build your own tensor type\n",
    "\n",
    "This tutorial presents a core feature of the PySyft library: the ability to add custom tensor types that provide specific behaviours such as encryption or traceability. This makes the library generic and open to new innovations in the field of privacy-preserving machine learning.\n",
    "\n",
    "We will go through a very simple example that could serve as the base for a traceability feature, keeping track of the operations performed on the data in a verifiable way. This new tensor type will log all operations executed on tensors of its type. Let's call this type the `CustomLoggingTensor`.\n",
    "\n",
    "Authors:\n",
    "- Théo Ryffel - Twitter: [@theoryffel](https://twitter.com/theoryffel)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 0. Preliminaries\n",
    "\n",
    "We use the sandbox introduced earlier, which provides the virtual workers `bob` and `alice` used below.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Setting up Sandbox...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "import torch as th\n",
    "import syft as sy\n",
    "sy.create_sandbox(globals(), verbose=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first recall the notions of Torch and Syft tensors. All the objects the end user manipulates are torch tensors. This is of course the case for a pure torch tensor (e.g. `x = th.tensor([1., 2])`), but also for syft objects such as the pointer tensor, which is a particular kind of syft tensor."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Wrapper)>[PointerTensor | me:49975446348 -> bob:27355730204]"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ptr = th.tensor([1., 2]).send(bob)\n",
    "ptr"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The wrapper object you see is actually an empty torch tensor with a `.child` attribute, which is a PointerTensor:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "isinstance(ptr, th.Tensor)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "syft.generic.pointers.pointer_tensor.PointerTensor"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(ptr.child)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is also true for more complex objects, where you also see this wrapper at the beginning. You can then have multiple Syft or Torch tensors chained through the `.child` attribute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Wrapper)>FixedPrecisionTensor>[AdditiveSharingTensor]\n",
       "\t-> [PointerTensor | me:97109067924 -> alice:10104773986]\n",
       "\t-> [PointerTensor | me:37015083947 -> bob:93808751958]\n",
       "\t*crypto provider: me*"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = th.tensor([1., 2]).fix_prec().share(alice, bob)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FixedPrecisionTensor>[AdditiveSharingTensor]\n",
       "\t-> [PointerTensor | me:97109067924 -> alice:10104773986]\n",
       "\t-> [PointerTensor | me:37015083947 -> bob:93808751958]\n",
       "\t*crypto provider: me*"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.child"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'alice': [PointerTensor | me:97109067924 -> alice:10104773986],\n",
       " 'bob': [PointerTensor | me:37015083947 -> bob:93808751958]}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x.child.child.child"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Recall that the general behaviour is the following: each time a command is called on the top object, it goes down the chain, where it can be modified. It is then executed at the bottom, and the result is wrapped back up with exactly the same chain structure, so that it keeps the same properties (such as traceability)."
   ]
  },
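  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make this round trip concrete, here is a toy sketch in plain Python. It is not PySyft code: the `Node` class and its `call` method are hypothetical names invented for this illustration, and a plain number stands in for the torch tensor at the bottom of the chain.\n",
    "\n",
    "```python\n",
    "class Node:\n",
    "    # Toy chain link (hypothetical, not a PySyft class).\n",
    "    def __init__(self, child):\n",
    "        self.child = child\n",
    "\n",
    "    def call(self, name, *args):\n",
    "        # Going down: replace each Node argument with its child.\n",
    "        args = [a.child if isinstance(a, Node) else a for a in args]\n",
    "        if isinstance(self.child, Node):\n",
    "            # Not at the bottom yet: forward the command to the next link.\n",
    "            result = self.child.call(name, *args)\n",
    "        else:\n",
    "            # Bottom of the chain: execute the real operation.\n",
    "            result = getattr(self.child, name)(*args)\n",
    "        # Going up: wrap the result in a node of the same type.\n",
    "        return type(self)(result)\n",
    "\n",
    "\n",
    "x = Node(Node(3))           # chain: Node > Node > 3\n",
    "y = x.call('__mul__', 2)    # the command travels down, the result comes back up\n",
    "print(type(y).__name__, type(y.child).__name__, y.child.child)  # Node Node 6\n",
    "```\n",
    "\n",
    "The result has the same two-level structure as the input, which is the property the real chain relies on."
   ]
  },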
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "What we're going to do here is to create our own syft Tensor type that we will be able to put in this chain!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. The MVP of the CustomLoggingTensor\n",
    "\n",
    "### 1.1 Declare the class type"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To get started, there isn't much to do. First, we need to create the tensor class.\n",
    "\n",
    "This is done in `syft/frameworks/torch/tensors/`. Choose the subfolder:\n",
    "- `interpreters` if the functionality you want to build will modify the results or functions, or\n",
    "- `decorators` if the functionality is just ... decorative.\n",
    "\n",
    "> _The `interpreters` / `decorators` split might be removed in the future, in which case just put your tensor in `syft/frameworks/torch/tensors/`_\n",
    "\n",
    "Here we'll put it in the `decorators` folder. Choose a simple **but explicit** name; for now `decorators/custom_logging.py` will be sufficient.\n",
    "\n",
    "There, write the minimal class definition, where our tensor inherits from `AbstractTensor`, an abstraction which gives default behaviours to Syft tensors:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from syft.generic.tensor import AbstractTensor\n",
    "\n",
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This was quite fast, wasn't it?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 1.2 Add to the hooks\n",
    "\n",
    "#### 1.2.1 Allow imports\n",
    "You now need to declare this type in the imports so that you can actually use it. Add it in the files:\n",
    "```\n",
    "- syft/frameworks/torch/tensors/[decorators|interpreters]/__init__.py\n",
    "- syft/__init__.py\n",
    "```\n",
    "You should now be able to import the tensor type: `from syft import CustomLoggingTensor`\n",
    "\n",
    "#### 1.2.2 Hook tensor to add torch operations\n",
    "In the `syft/frameworks/torch/hook/hook.py` file:\n",
    "- Add an import line at the top\n",
    "- Add the following in the TorchHook `__init__`: `self._hook_syft_tensor_methods(CustomLoggingTensor)`\n",
    "\n",
    "All instances of CustomLoggingTensor now have, for example, a `.add(...)` method. We now need to specify how these methods should handle their arguments.\n",
    "\n",
    "#### 1.2.3 Hook tensor to correctly forward arguments to the torch operations\n",
    "\n",
    "In particular, we would like arguments provided as CustomLoggingTensor to be unwrapped and replaced with their `.child` attribute, to go down the chain.\n",
    "\n",
    "In the `syft/generic/frameworks/hook/hook_args.py` file:\n",
    "- Add an import line at the top\n",
    "- Extend the `type_rule` dict with `CustomLoggingTensor: one,` (means that this type of tensors supports (un)wrapping)\n",
    "- Extend the `forward_func` dict with `CustomLoggingTensor: lambda i: i.child,` (explains how to unwrap)\n",
    "- Extend the `backward_func` dict with `CustomLoggingTensor: lambda i, **kwargs: CustomLoggingTensor(**kwargs).on(i, wrap=False),` (explains how to wrap)\n",
    "\n",
    "\n",
    "Et voilà! You can already do many things with your new tensor!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "CustomLoggingTensor>None"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = CustomLoggingTensor()\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OK, this is not super useful, but it comes with a `.on` method which works as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Wrapper)>CustomLoggingTensor>tensor([1., 2.])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = th.tensor([1., 2])\n",
    "x = CustomLoggingTensor().on(x)\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`.on` simply inserts the tensor node into a tensor chain. As we always need to have a torch tensor at the top of the chain, a wrapper was automatically added.\n",
    "\n",
    "### Ready to use!\n",
    "At this point, to get the desired behaviour, you should make the code changes in the repository: integrating the code there lets you benefit from the hooking functionalities.\n",
    "In particular, after re-instantiating the hook, your `CustomLoggingTensor` should have the methods a pure torch tensor has. \n",
    "\n",
    "> **Make the change and re-run the notebook up to here**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This time, we use the `sy.` prefix, meaning the class comes from the repository."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Wrapper)>CustomLoggingTensor>tensor([1., 2.])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = th.tensor([1., 2])\n",
    "x = sy.CustomLoggingTensor().on(x)\n",
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can do computations on this chain, such as `x * 2`: the resulting `__mul__` call is forwarded all the way down the chain to the last node, which is a pure torch tensor, whose value is doubled."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(Wrapper)>CustomLoggingTensor>tensor([2., 4.])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x * 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you correctly obtained `(Wrapper)>CustomLoggingTensor>tensor([2., 4.])`, you're all set!"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Adding special functionalities"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that you have defined your own tensor type, you should specify its behaviour, as by default it won't do anything special and will just act passively.\n",
    "\n",
    "In this part, we will see how to specify custom functionalities. For the parts that are executed, we'll use the already existing `LoggingTensor` instead of the `CustomLoggingTensor` and highlight which part of the code produces which functionality, so that you can run code in this notebook without reloading the kernel. If you want to practice more, you can apply the code changes to the `CustomLoggingTensor` class definition and you'll observe the same behaviours (just reload the notebook each time you modify the library code)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from syft import LoggingTensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.1 Default behaviour for functions\n",
    "\n",
    "You can add a special functionality each time a (hooked) torch **function** is called on `LoggingTensor`: here we just print the call.\n",
    "\n",
    "Note that this is for functions exclusively and not for methods, but it applies to all hooked torch functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "\n",
    "    @classmethod\n",
    "    def on_function_call(cls, command):\n",
    "        \"\"\"\n",
    "        Override this to perform a specific action for each call of a torch\n",
    "        function with arguments containing syft tensors of the class doing\n",
    "        the overloading\n",
    "        \"\"\"\n",
    "        cmd, _, args, kwargs = command\n",
    "        print(\"Default log\", cmd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Default log torch.div\n",
      "Default log torch.nn.functional.celu\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(Wrapper)>LoggingTensor>tensor([1., 2.])"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = th.tensor([1., 2])\n",
    "x = LoggingTensor().on(x)\n",
    "\n",
    "th.div(x, x)\n",
    "th.nn.functional.celu(x) # celu is a variant of the activation function relu(x) = max(0, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "> Note: this `on_function_call` is called by `handle_func_command` which comes from the `AbstractTensor`: it explains how to propagate functions down the chain, and in some cases you might also need to change it."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.2 Overloading torch methods\n",
    "\n",
    "We introduce here an important decorator object, `@overloaded`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "from syft.generic.frameworks.overload import overloaded"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can directly overload torch methods like this: here we overload the `.add` method so that it first prints that it was called and then forwards the call to the `.child` attribute."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    def __init__(self, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        \n",
    "    @overloaded.method\n",
    "    def add(self, _self, *args, **kwargs):\n",
    "        print(\"Log method add\")\n",
    "        response = _self.add(*args, **kwargs)\n",
    "        return response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This is an example of how to use the `@overloaded.method` decorator. To see\n",
    "what this decorator does, look at the `manual_add` method below: it does\n",
    "exactly the same thing but without the decorator.\n",
    "\n",
    "Note the subtlety between `self` and `_self`: to forward the call you should use `_self` and **NOT** `self`. We kept `self` because it can hold useful attributes that you might want to access (for example, for the fixed precision tensor it stores the field size).\n",
    "\n",
    "Here is the version of the add method without the decorator: as you can see,\n",
    "it is much more complicated. However, you might sometimes need it to specify\n",
    "some particular behaviour, so this is where to start from if needed!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import syft  # needed for the fully qualified hook_args calls below\n",
    "\n",
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    \n",
    "    # [...]\n",
    "    \n",
    "    def manual_add(self, *args, **kwargs):\n",
    "        # Replace all syft tensor with their child attribute\n",
    "        new_self, new_args, new_kwargs = syft.generic.frameworks.hook.hook_args.hook_method_args(\n",
    "            \"add\", self, args, kwargs\n",
    "        )\n",
    "\n",
    "        print(\"Log method manual_add\")\n",
    "        # Send it to the appropriate class and get the response\n",
    "        response = getattr(new_self, \"add\")(*new_args, **new_kwargs)\n",
    "\n",
    "        # Put back SyftTensor on the tensors found in the response\n",
    "        response = syft.generic.frameworks.hook.hook_args.hook_response(\n",
    "            \"add\", response, wrap_type=type(self)\n",
    "        )\n",
    "        return response"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "They behave exactly the same and print a line when called:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Wrapper)>LoggingTensor>tensor([1., 2.])\n",
      "Log method add\n"
     ]
    }
   ],
   "source": [
    "x = LoggingTensor().on(th.tensor([1., 2]))\n",
    "print(x)\n",
    "\n",
    "r = x.add(x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "_You might want to try running_ `r = x.manual_add(x)` _but this will fail: although the LoggingTensor `x.child` has a_ `.manual_add` _method, the wrapper `x` does not, as torch tensors don't have_ `.manual_add` _by default._"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.3 Overloading torch functions\n",
    "\n",
    "We will still use the @overloaded decorator but now with:\n",
    "\n",
    "```\n",
    "- @overloaded.module\n",
    "- @overloaded.function\n",
    "```\n",
    "\n",
    "What we want to do is to overload \n",
    "\n",
    "```\n",
    "- torch.add\n",
    "- torch.nn.functional.relu\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    \n",
    "    # [...] \n",
    "    \n",
    "    @staticmethod\n",
    "    @overloaded.module\n",
    "    def torch(module):\n",
    "        \"\"\"\n",
    "        We use the @overloaded.module to specify we're writing here\n",
    "        a function which should overload the function with the same\n",
    "        name in the <torch> module\n",
    "        :param module: object which stores the overloading functions\n",
    "\n",
    "        Note that we used the @staticmethod decorator as we're in a\n",
    "        class\n",
    "        \"\"\"\n",
    "\n",
    "        def add(x, y):\n",
    "            \"\"\"\n",
    "            You can write the function to overload in the most natural\n",
    "            way, so this will be called whenever you call torch.add on\n",
    "            Logging Tensors, and the x and y you get are also Logging\n",
    "            Tensors, so compared to the @overloaded.method, you see\n",
    "            that the @overloaded.module does not hook the arguments.\n",
    "            \"\"\"\n",
    "            print(\"Log function torch.add\")\n",
    "            return x + y\n",
    "\n",
    "        # Just register it using the module variable\n",
    "        module.add = add\n",
    "\n",
    "        @overloaded.function\n",
    "        def mul(x, y):\n",
    "            \"\"\"\n",
    "            You can also add the @overloaded.function decorator to also\n",
    "            hook arguments, ie all the LoggingTensor are replaced with\n",
    "            their child attribute\n",
    "            \"\"\"\n",
    "            print(\"Log function torch.mul\")\n",
    "            return x * y\n",
    "\n",
    "        # Just register it using the module variable\n",
    "        module.mul = mul\n",
    "\n",
    "        # You can also overload functions in submodules!\n",
    "        @overloaded.module\n",
    "        def nn(module):\n",
    "            \"\"\"\n",
    "            The syntax is the same, so @overloaded.module handles recursion\n",
    "            Note that we don't need to add the @staticmethod decorator\n",
    "            \"\"\"\n",
    "\n",
    "            @overloaded.module\n",
    "            def functional(module):\n",
    "                @overloaded.function\n",
    "                def relu(x):\n",
    "                    print(\"Log function torch.nn.functional.relu\")\n",
    "                    return x * (x > 0).float()\n",
    "\n",
    "                module.relu = relu\n",
    "\n",
    "            module.functional = functional\n",
    "\n",
    "        # Modules should be registered just like functions\n",
    "        module.nn = nn"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note the difference between `def add` and `def mul`: `def add` doesn't have `@overloaded.function`, which means that its args are not unwrapped: they are CustomLoggingTensors. In `def mul` they are unwrapped and replaced by their child attributes, so Torch tensors in our case."
   ]
  },
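  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The unwrapping difference can be imitated with a toy decorator in plain Python. This is only a sketch under stated assumptions, not how PySyft implements `@overloaded.function`: `Wrapped` and `unwrap_args` are hypothetical names, and plain numbers stand in for torch tensors.\n",
    "\n",
    "```python\n",
    "import functools\n",
    "\n",
    "\n",
    "class Wrapped:\n",
    "    # Hypothetical stand-in for a syft tensor holding a `.child`.\n",
    "    def __init__(self, child):\n",
    "        self.child = child\n",
    "\n",
    "\n",
    "def unwrap_args(func):\n",
    "    # Hypothetical decorator: unwrap the arguments, call func, rewrap the result.\n",
    "    @functools.wraps(func)\n",
    "    def hooked(*args):\n",
    "        raw = [a.child if isinstance(a, Wrapped) else a for a in args]\n",
    "        return Wrapped(func(*raw))\n",
    "    return hooked\n",
    "\n",
    "\n",
    "def add(x, y):\n",
    "    # No decorator: x and y are still Wrapped objects here.\n",
    "    return Wrapped(x.child + y.child)\n",
    "\n",
    "\n",
    "@unwrap_args\n",
    "def mul(x, y):\n",
    "    # With the decorator: x and y are already the plain children.\n",
    "    return x * y\n",
    "\n",
    "\n",
    "a, b = Wrapped(2), Wrapped(5)\n",
    "print(add(a, b).child, mul(a, b).child)  # 7 10\n",
    "```\n",
    "\n",
    "Both functions return a wrapped result; the only difference is whether the body sees the wrapper or its child."
   ]
  },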
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Compare this with 2.1: the behaviour is not much different, but the modified functions are now very precisely targeted:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Default log torch.div\n",
      "Log function torch.add\n"
     ]
    }
   ],
   "source": [
    "x = th.tensor([1., 2])\n",
    "x = LoggingTensor().on(x)\n",
    "\n",
    "# Default overloading made in 2.1\n",
    "r = th.div(x, x)\n",
    "\n",
    "# Targeted overloading made in 2.3\n",
    "r = th.add(x, x)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Also, compared to 2.1, we changed the function behaviour: for relu, for example, instead of running the built-in relu we run `x * (x > 0).float()`, even though the output is the same. We could also have called the native relu inside, provided that we unwrap the args using `@overloaded.function`; otherwise we would loop indefinitely.\n",
    "```python\n",
    "@overloaded.module\n",
    "def functional(module):\n",
    "    @overloaded.function\n",
    "    def relu(x):\n",
    "        print(\"Log function torch.nn.functional.relu\")\n",
    "        return torch.nn.functional.relu(x)\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.4 Adding custom tensor attributes\n",
    "\n",
    "Sometimes you need to add special attributes to your Syft Tensor, like the FixedPrecisionTensor which has a field attribute for example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Wrapper)>FixedPrecisionTensor>tensor([1000, 2000])\n",
      "Field: 4611686018427387904\n"
     ]
    }
   ],
   "source": [
    "fp = th.tensor([1., 2]).fix_precision()\n",
    "print(fp)\n",
    "print(\"Field:\", fp.child.field)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Just declare them in `__init__`, for example a `log_max_size`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    def __init__(self, log_max_size=64, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.log_max_size = log_max_size\n",
    "        \n",
    "    # [...]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make sure this value is correctly added to the response of an operation, when the chain is rebuilt and a CustomLoggingTensor is wrapped on top of the result, you should declare `get_class_attributes`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    \n",
    "    # [...]\n",
    "    \n",
    "    def get_class_attributes(self):\n",
    "        \"\"\"\n",
    "        Return all elements which define an instance of a certain class.\n",
    "        \"\"\"\n",
    "        return {\n",
    "            'log_max_size': self.log_max_size\n",
    "        }"
   ]
  },
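  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To illustrate why this matters, here is a toy sketch in plain Python. It is not PySyft code: `ToyLoggingTensor` is a hypothetical stand-in, and its hand-written `add` only imitates the rewrapping step. The attributes returned by `get_class_attributes` are passed back to the constructor when the result is wrapped, so `log_max_size` survives the operation instead of falling back to its default.\n",
    "\n",
    "```python\n",
    "class ToyLoggingTensor:\n",
    "    # Toy stand-in for a syft tensor (hypothetical, not a PySyft class).\n",
    "    def __init__(self, child=None, log_max_size=64):\n",
    "        self.child = child\n",
    "        self.log_max_size = log_max_size\n",
    "\n",
    "    def get_class_attributes(self):\n",
    "        # Everything needed to rebuild an equivalent node around a new child.\n",
    "        return {'log_max_size': self.log_max_size}\n",
    "\n",
    "    def add(self, other):\n",
    "        result = self.child + other.child\n",
    "        # Rebuild the node around the result, reusing this node's attributes.\n",
    "        return type(self)(child=result, **self.get_class_attributes())\n",
    "\n",
    "\n",
    "x = ToyLoggingTensor(child=1.0, log_max_size=8)\n",
    "r = x.add(x)\n",
    "print(r.child, r.log_max_size)  # 2.0 8\n",
    "```"
   ]
  },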
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2.5 Sending CustomLoggingTensors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One last thing we love to do is to send tensors across workers!\n",
    "\n",
    "\n",
    "To do so, you need to add a serializer and a deserializer to the class:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add these new imports\n",
    "import syft\n",
    "from syft.workers.abstract import AbstractWorker\n",
    "\n",
    "class CustomLoggingTensor(AbstractTensor):\n",
    "    \n",
    "    # [...]\n",
    "    \n",
    "    @staticmethod\n",
    "    def simplify(tensor: \"CustomLoggingTensor\") -> tuple:\n",
    "        \"\"\"Takes the attributes of a CustomLoggingTensor and saves them in a tuple.\n",
    "\n",
    "        Args:\n",
    "            tensor: a CustomLoggingTensor.\n",
    "\n",
    "        Returns:\n",
    "            tuple: a tuple holding the unique attributes of the CustomLoggingTensor.\n",
    "        \"\"\"\n",
    "        chain = None\n",
    "        if hasattr(tensor, \"child\"):\n",
    "            chain = syft.serde._simplify(tensor.child)\n",
    "\n",
    "        return (\n",
    "            syft.serde._simplify(tensor.id),\n",
    "            tensor.log_max_size,\n",
    "            syft.serde._simplify(tensor.tags),\n",
    "            syft.serde._simplify(tensor.description),\n",
    "            chain,\n",
    "        )\n",
    "\n",
    "    @staticmethod\n",
    "    def detail(worker: AbstractWorker, tensor_tuple: tuple) -> \"CustomLoggingTensor\":\n",
    "        \"\"\"Reconstructs a CustomLoggingTensor given its attributes in the form of a tuple.\n",
    "\n",
    "        Args:\n",
    "            worker: the worker doing the deserialization.\n",
    "            tensor_tuple: a tuple holding the attributes of the CustomLoggingTensor.\n",
    "\n",
    "        Returns:\n",
    "            CustomLoggingTensor: a CustomLoggingTensor.\n",
    "\n",
    "        Examples:\n",
    "            tensor = CustomLoggingTensor.detail(worker, tensor_tuple)\n",
    "        \"\"\"\n",
    "\n",
    "        tensor_id, log_max_size, tags, description, chain = tensor_tuple\n",
    "\n",
    "        tensor = CustomLoggingTensor(\n",
    "            owner=worker,\n",
    "            id=syft.serde._detail(worker, tensor_id),\n",
    "            log_max_size=log_max_size,\n",
    "            tags=syft.serde._detail(worker, tags),\n",
    "            description=syft.serde._detail(worker, description),\n",
    "        )\n",
    "\n",
    "        if chain is not None:\n",
    "            chain = syft.serde._detail(worker, chain)\n",
    "            tensor.child = chain\n",
    "\n",
    "        return tensor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then declare this new tensor to the serialization module. In `serde/serde.py`:\n",
    "- Add an import for CustomLoggingTensor\n",
    "- Append CustomLoggingTensor to the `OBJ_SIMPLIFIER_AND_DETAILERS` list\n",
    "\n",
    "Everything should now work correctly:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(Wrapper)>[PointerTensor | me:59260307020 -> alice:64823483938]\n",
      "(Wrapper)>LoggingTensor>tensor([2., 4.])\n"
     ]
    }
   ],
   "source": [
    "x = th.tensor([1., 2])\n",
    "x = sy.LoggingTensor().on(x)\n",
    "\n",
    "p = x.send(alice)\n",
    "print(p)\n",
    "p2 = p + p\n",
    "x2 = p2.get()\n",
    "print(x2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And there you are! You should now understand all the tools we've built so that you can easily build new tensor types and focus on their behaviour rather than on their integration in the PySyft library."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Congratulations!!! - Time to Join the Community!\n",
    "\n",
    "Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!\n",
    "\n",
    "### Star PySyft on GitHub\n",
    "\n",
    "The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.\n",
    "\n",
    "- [Star PySyft](https://github.com/OpenMined/PySyft)\n",
    "\n",
    "### Join our Slack!\n",
    "\n",
    "The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at [http://slack.openmined.org](http://slack.openmined.org)\n",
    "\n",
    "### Join a Code Project!\n",
    "\n",
    "The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for \"Projects\". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more \"one off\" mini-projects by searching for GitHub issues marked \"good first issue\".\n",
    "\n",
    "- [PySyft Projects](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3AProject)\n",
    "- [Good First Issue Tickets](https://github.com/OpenMined/PySyft/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22)\n",
    "\n",
    "### Donate\n",
    "\n",
    "If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!\n",
    "\n",
    "[OpenMined's Open Collective Page](https://opencollective.com/openmined)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
